import numpy as np
import pickle
import matplotlib.pyplot as plt
import os


def getdope(ams, u=8):
    '''
    Estimate the doping of a scan from the mean fraction of sites equal to 1.

    :param ams: sequence whose first element is the list of atom matrices to use
                (the "noblow" shots); 0 marks sites outside the region of interest
    :param u: U/t of the scan; only U/t = 8 has a calibration
    :return: (doping, None) -- the second slot is a placeholder for an error
             estimate that is not computed
    :raises NotImplementedError: if u != 8
    '''
    noblow_shots = ams[0]

    # The ROI size is read from the first shot and assumed identical across the scan.
    n_sites = np.sum(np.asarray(noblow_shots[0]) != 0)

    # Fraction of ROI sites that hold the value 1 ("singles"), per shot.
    densities = [float(np.sum(np.asarray(shot) == 1)) / n_sites
                 for shot in noblow_shots]
    singles = np.mean(densities)

    if u != 8:
        raise NotImplementedError('Doping estimate not implemented for U/t != 8')

    # Linear calibration for U/t = 8 (constants carried over from the original code).
    return 1.22 * (0.905 - singles), None


# 7x7 octagonal analysis window used by postselect; True marks a site inside it.
_PS_WINDOW = np.array([[0, 0, 1, 1, 1, 0, 0],
                       [0, 1, 1, 1, 1, 1, 0],
                       [1, 1, 1, 1, 1, 1, 1],
                       [1, 1, 1, 1, 1, 1, 1],
                       [1, 1, 1, 1, 1, 1, 1],
                       [0, 1, 1, 1, 1, 1, 0],
                       [0, 0, 1, 1, 1, 0, 0]], dtype=bool)


def postselect(ams, ps):
    '''
    Keep the fraction ``ps`` of shots with the largest staggered magnetization.

    For every shot the octagonal window _PS_WINDOW is slid over the bounding box
    of the non-zero sites. Among the positions where every window site is
    non-zero, the one with the largest |sum (-1)**(i+j) * am[i, j]| is kept,
    with all sites outside the window set to 0. Shots are then ranked by that
    value divided by sum(|window|), in descending order.

    :param ams: iterable of 2D atom matrices (0 marks sites outside the ROI)
    :param ps: fraction of shots to keep; int(ps*N + 0.5) shots are returned,
               clamped to [0, N]
    :return: list of the windowed matrices of the kept shots, best first
    :raises ValueError: if a shot has no non-zero site, or no window position
                        fits entirely inside its non-zero region
    '''
    win_rows, win_cols = _PS_WINDOW.shape
    n_window = int(_PS_WINDOW.sum())

    sms = []
    ms = []
    for am in ams:
        m = np.array(am)

        rows, cols = np.nonzero(m)
        if rows.size == 0:
            raise ValueError('postselect: shot has no non-zero sites')
        x_min, x_max = rows.min(), rows.max()
        y_min, y_max = cols.min(), cols.max()

        # Checkerboard of +1/-1 over the whole array; summing sign * window over
        # the full array avoids the negative-index wrap-around of a local loop.
        xs, ys = np.indices(m.shape)
        sign = 1 - 2 * ((xs + ys) % 2)

        best_m, best_sm = None, None
        for x1 in range(x_min, x_max - win_rows + 2):
            for y1 in range(y_min, y_max - win_cols + 2):
                inside = np.zeros(m.shape, dtype=bool)
                inside[x1:x1 + win_rows, y1:y1 + win_cols] = _PS_WINDOW
                m1 = m.copy()
                m1[~inside] = 0

                # Only accept positions where the window lies completely in the ROI.
                if np.abs(m1).sum() != n_window:
                    continue

                sm = abs(np.sum(sign * m1))
                # Strict '>' keeps the first maximum, like np.argmax.
                if best_sm is None or sm > best_sm:
                    best_m, best_sm = m1, sm

        if best_m is None:
            raise ValueError('postselect: ROI too small to fit the 7x7 window')

        sms.append(1.0 * best_sm / np.sum(np.abs(best_m)))
        ms.append(best_m)

    # Sort the shots by staggered magnetization in descending order (stable).
    order = sorted(range(len(sms)), key=lambda k: -abs(sms[k]))
    sms = [sms[k] for k in order]
    ms = [ms[k] for k in order]

    n_taken = min(max(int(ps * len(sms) + 0.5), 0), len(sms))

    print('-------------- shots total: ' + str(len(sms)) + ' --------------')
    print('shots taken: ' + str(n_taken))
    if n_taken > 0:
        print('lowest stagg mag taken: ' + str(sms[n_taken - 1]))

    return ms[:n_taken]


# Nearest-neighbour offsets in the order the string tracer probes them:
# previous row, next row, previous column, next column.
_NEIGHBOURS = ((-1, 0), (1, 0), (0, -1), (0, 1))

# Tracer transition table. For each walking state: the value a neighbouring
# site must hold to be the next step, and {happiness of that site: next state}.
# A next site whose happiness is not listed invalidates the string.
_STEPS = {
    0: (-1, {2: 2, 1: 4}),  # atom, happiness 1 (string start) -> hole
    1: (1, {2: 3}),         # hole, happiness 1 (string start) -> atom
    2: (1, {2: 3}),         # hole, happiness 2 -> atom
    3: (-1, {2: 2, 1: 4}),  # atom, happiness 2 -> hole
}

# Sentinel 'count' values; any count <= _INVALID is discarded.
_INVALID = -100     # string ran into a site of the wrong happiness
_BUG = -1000        # no admissible neighbour found (should be unreachable)
_CONSUMED = -10000  # far end of a string that has already been recorded


def _site_key(site):
    '''Return the string_states key used for lattice site ``site`` = (x, y).'''
    return "{}, {}".format(site[0], site[1])


def _neighbours(site):
    '''Return the four nearest neighbours of ``site`` in _NEIGHBOURS order.'''
    x, y = site
    return [(x + dx, y + dy) for dx, dy in _NEIGHBOURS]


def _happiness_map(am, mask):
    '''
    Score site happiness and collect the candidate string start points of a shot.

    A site is scored only if it is not on the array edge (so no index wraps
    around) and it and its four neighbours are inside the ROI. Its happiness is
    the number of neighbours whose value differs from its own.

    :return happiness_map: int array; happiness of scored sites, -1 elsewhere
    :return string_states: dict keyed by _site_key, in row-major order, with one
        tracer record per start candidate (state 0: atom of happiness 1,
        state 1: hole of happiness 1, state 6: hole of happiness 0)
    '''
    happiness_map = np.full(am.shape, -1, dtype=int)
    string_states = {}
    for x in range(1, am.shape[0] - 1):
        for y in range(1, am.shape[1] - 1):
            nbs = _neighbours((x, y))
            if not (mask[x, y] and all(mask[nb] for nb in nbs)):
                continue
            happiness = sum(1 for nb in nbs if am[nb] != am[x, y])
            happiness_map[x, y] = happiness
            if happiness == 1:
                state = 0 if am[x, y] == 1 else 1
            elif happiness == 0 and am[x, y] == -1:
                state = 6
            else:
                continue
            string_states[_site_key((x, y))] = {'state': state,
                                                'marker': (x, y),
                                                'prev_marker': [(x, y)],
                                                'count': 0}
    return happiness_map, string_states


def _trace_string(am, happiness_map, string_states, value):
    '''
    Walk one candidate start point through the string state machine.

    Mutates ``value`` ('state', 'marker', 'prev_marker', 'count'). When a string
    is completed, marks its far end in ``string_states`` as consumed so the same
    string is not recorded a second time from the other end.

    :return: length to record (0 for the happiness-0 hole patterns), or None if
             the candidate is discarded
    '''
    is_pattern = False
    while value['count'] >= 0:
        state = value['state']
        marker = value['marker']
        if state in _STEPS:
            wanted, transitions = _STEPS[state]
            # Never step straight back to the site we came from. At a start
            # point prev_marker[-1] is the start itself, so nothing is excluded.
            came_from = value['prev_marker'][-1]
            nxt = None
            for nb in _neighbours(marker):
                if nb != came_from and am[nb] == wanted:
                    nxt = nb
                    break
            if nxt is None:  # unreachable given the happiness of marker
                value['count'] = _BUG
                print("BUG")
                continue
            value['prev_marker'].append(marker)
            value['marker'] = nxt
            next_state = transitions.get(int(happiness_map[nxt]))
            if next_state is None:
                value['count'] = _INVALID  # string is invalid
            else:
                value['state'] = next_state
                value['count'] += 1
        elif state == 4:  # hole, happiness 1, reached along a string: complete
            value['count'] = -value['count']
            end = string_states.get(_site_key(marker))
            if end is not None:
                end['count'] = _CONSUMED
            value['prev_marker'].append(marker)
        elif state == 6:  # hole, happiness 0: "0, 30 or 303" pattern
            threes = [nb for nb in _neighbours(marker)
                      if am[nb] == -1 and happiness_map[nb] == 3]
            value['prev_marker'].append(marker)
            value['prev_marker'].extend(threes)
            value['count'] = -(value['count'] + min(len(threes), 2)) - 7
            is_pattern = True
        else:
            raise ValueError('unknown string state {}'.format(state))

    if value['count'] <= _INVALID:  # invalid string or tracer bug
        return None
    # Patterns are recorded as 0 to match Annabelle's code. An explicit flag is
    # used (not a count threshold) so genuine strings of length >= 7 keep
    # their true length.
    return 0 if is_pattern else -value['count']


def find_strings(ams, mt=0):
    '''
    Find strings (alternating atom/hole paths) in a set of atom matrices.

    :param ams: iterable of 2D atom matrices. 1: atom; 0: n/a; -1: no atom
    :param mt: float, magnetization threshold. A shot is analysed only if
               |staggered magnetization| / (number of ROI sites) exceeds
               min(mt, 0.9).
    :return string_lengths: list with one entry per analysed shot, each a list
        of the recorded string lengths (0 for the happiness-0 hole patterns)
    :return densities: list, per analysed shot, number of recorded strings with
        length other than 0 and 1, divided by numsites
    :return numsites: number of non-zero sites in the LAST shot of ``ams``
        (0 if ``ams`` is empty); assumed equal for all shots
    '''
    string_lengths = []
    countedstrings = []
    numsites = 0
    for am in ams:
        am = np.asarray(am)
        mask = am != 0
        numsites = np.sum(mask)

        # Staggered magnetization: sum of (-1)**(x+y) * am[x, y] per ROI site.
        xs, ys = np.indices(am.shape)
        sign = 1 - 2 * ((xs + ys) % 2)
        single_mag = float(np.sum(am * sign)) / numsites

        # Written as "not >" so a NaN magnetization (empty ROI) is skipped.
        if not np.abs(single_mag) > min(float(mt), 0.9):
            continue

        happiness_map, string_states = _happiness_map(am, mask)

        lengths = []
        for value in string_states.values():
            length = _trace_string(am, happiness_map, string_states, value)
            if length is not None:
                lengths.append(length)

        string_lengths.append(lengths)
        countedstrings.append(len([n for n in lengths if n not in (0, 1)]))

    return string_lengths, [cs / float(numsites) for cs in countedstrings], numsites


def find_strings_experiment(fns, sel=None, ps=1.0, gb=1.0):
    '''
    Group snapshot files by doping and compute string statistics per group.

    :param fns: list of paths to pickled snapshot files. Each file must hold a
                sequence whose elements 0, 1 and 2 are lists of atom matrices
                (only elements 1 and 2 are analysed). The doping is parsed from
                the text after the last underscore, minus its first character
                and the 4-character extension (e.g. "..._D10.pkl" -> 10).
    :param sel: dopings (after binning) for which a length histogram is exported
    :param ps: fraction of shots kept by postselect in every doping group
    :param gb: bin width used to round dopings into groups
    :return out: dict of lists, one entry per doping group in ascending order:
                 'doping', 'sc' (mean string density) and 'sces' (its standard
                 error over shots)
    :return selout: one dict ('lengthFCS_mean', 'lengthFCS_std') per doping in
                    ``sel`` that exists in the data, in ASCENDING doping order
                    (not in the order of ``sel``)
    '''
    if sel is None:
        sel = []

    # Group by doping within bins of gb.
    grouped = {}
    for amfn in fns:
        # NOTE: pickle.load executes arbitrary code -- only use trusted files.
        with open(amfn, 'rb') as f:
            ams = pickle.load(f)
        d = float(amfn.split('_')[-1][1:-4])
        print('  Processing  ' + amfn.replace('\\', '/').split('/')[-1])

        dround = int(gb * round(d / gb))
        key = '{:02d}'.format(dround)
        parts = [list(ams[0]), list(ams[1]), list(ams[2])]

        if key in grouped:
            print('repeat doping found:  ' + str(dround))
            for acc, new in zip(grouped[key], parts):
                acc.extend(new)
        else:
            grouped[key] = parts

    # For each doping value, cut down to the top ps and look for strings.
    selout = []
    out = {'doping': [],
           'sc': [],
           'sces': []}

    for key in sorted(grouped, key=int):  # numeric, not lexicographic, order
        item = grouped[key]
        doping = int(key)

        ams = postselect(item[1] + item[2], ps)
        sls, css, ns = find_strings(ams)

        if doping in sel:
            print('exporting histogram data ' + str(doping))
            # Per-shot histogram of string lengths 0..7, normalised per site.
            counts = [np.histogram(sl, bins=8, range=(0, 8))[0] for sl in sls]
            n_shots = np.array(counts).shape[0]
            selout.append({
                'lengthFCS_mean': np.mean(counts, axis=0) / ns,
                'lengthFCS_std': np.std(counts, axis=0) / ns / np.sqrt(n_shots),
            })

        cs_means = [np.mean(cs) for cs in css]
        out['doping'].append(doping)
        out['sc'].append(np.mean(cs_means))
        out['sces'].append(np.std(cs_means) / np.sqrt(len(cs_means)))

    return out, selout


if __name__ == '__main__':
    # parameters go here
    # Backslashes are escaped explicitly; the original relied on invalid escape
    # sequences such as '\P' and '\c', which newer Pythons warn about.
    amdir = 'V:\\Paper Data\\AFM_StringTheory\\'
    ps = 0.6

    # (snapshot subdirectory, filename filter or None for every file,
    #  output-file prefix, label for progress messages, dopings to export)
    datasets = [
        ('Snapshots_experiment', 'cold_D', 'cold', 'experiment', [0, 10]),
        ('Snapshots_strings', 'gst0.60J_cold', 'gst0.60J_cold', 'string', [10]),
        ('Snapshots_sprinkled', 'peak0_cold', 'peak0_cold', 'sprinkled', [10]),
        ('Snapshots_piflux', None, 'piflux_cold', 'piflux', [10]),
    ]

    for subdir, pattern, prefix, label, sel in datasets:
        snapdir = amdir + subdir + '\\'
        # sorted() makes the processing order reproducible; os.listdir is arbitrary.
        fns = [snapdir + fn for fn in sorted(os.listdir(snapdir))
               if pattern is None or pattern in fn]

        out, selout = find_strings_experiment(fns, sel=sel, ps=ps, gb=2.0)

        with open(amdir + 'FCS\\' + prefix + '_happiness_avgs.pkl', 'wb') as f:
            pickle.dump(out, f)

        # selout only has entries for selected dopings present in the data, so
        # pairing it with sel is only valid when every selected doping was found.
        if len(selout) != len(sel):
            print('WARNING: ' + label + ': ' + str(len(selout)) + ' histograms for '
                  + str(len(sel)) + ' selected dopings; file labels may be wrong')
        for dat, d in zip(selout, sel):
            print('saving ' + label + ' hist ' + str(d))
            with open(amdir + 'FCS\\' + prefix + '_D{}_happiness_hist.pkl'.format(d), 'wb') as f:
                pickle.dump(dat, f)

